// Copyright (c) 2021 Terminus, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"strings"
	"text/template"
)

// KeyInfo represents a mapKey definition extracted from AST.
// One KeyInfo is produced per discovered key struct and is handed to
// generatedTemplate (as an element of .Keys) to render that key's helpers.
type KeyInfo struct {
	Name          string // e.g., "Client"; suffix of the generated Get/Put/MustGet function names
	MapKeyName    string // e.g., "mapKeyClient"; key struct type, instantiated as MapKeyName{} in generated code
	Type          string // e.g., "*clientpb.Client"; value type written verbatim into generated signatures
	HasCustomMust bool   // Whether MustGet has custom implementation; when true the template emits no MustGet
}

// customMustGetKeys holds the key type names listed in KeysWithCustomMustGet.
// It starts out empty (but non-nil) and is replaced by main after keys.go has
// been parsed; hasCustomMustGet consults it.
var customMustGetKeys = make(map[string]bool)

// logf emits a printf-style debug message on stdout and then again on stderr,
// so it shows up whichever stream `go generate` happens to surface.
func logf(format string, args ...any) {
	for _, w := range []*os.File{os.Stdout, os.Stderr} {
		fmt.Fprintf(w, format, args...)
	}
}
// logln is the line-oriented sibling of logf: it writes s followed by a
// newline to stdout and then to stderr.
func logln(s string) {
	line := s + "\n"
	fmt.Fprint(os.Stdout, line)
	fmt.Fprint(os.Stderr, line)
}

// generatedTemplate is the text/template rendered into generated.go.
//
// It is executed with a value exposing:
//   - Imports: []string; each entry is written verbatim as one import spec
//     after "context".
//   - Keys: []KeyInfo; Get<Name> and Put<Name> are emitted for every key, and
//     MustGet<Name> (which panics when the value is absent) only when
//     HasCustomMust is false.
//
// The emitted code calls getFromMapKeyAs and putToMapKey, which are not
// defined in this generator. They are presumably hand-written elsewhere in
// package ctxhelper (verify). The template does no whitespace trimming, so the
// raw output contains extra blank lines unless it is formatted afterwards.
const generatedTemplate = `// Code generated by go generate; DO NOT EDIT.
// This file was automatically generated from keys.go

package ctxhelper

import (
	"context"
{{range .Imports}}
	{{.}}
{{end}}
)

{{range .Keys}}
// Get{{.Name}} retrieves {{.Name}} from context
func Get{{.Name}}(ctx context.Context) ({{.Type}}, bool) {
	return getFromMapKeyAs[{{.Type}}](ctx, {{.MapKeyName}}{})
}

// Put{{.Name}} stores {{.Name}} in context
func Put{{.Name}}(ctx context.Context, value {{.Type}}) {
	putToMapKey(ctx, {{.MapKeyName}}{}, value)
}

{{if not .HasCustomMust}}
// MustGet{{.Name}} retrieves {{.Name}} from context, panics if not found
func MustGet{{.Name}}(ctx context.Context) {{.Type}} {
	value, ok := Get{{.Name}}(ctx)
	if !ok {
		panic("{{.Name}} not found in context")
	}
	return value
}
{{end}}

{{end}}
`

// extractFunctionName extracts function name from mapKeyXxx
func extractFunctionName(mapKeyName string) string {
	if strings.HasPrefix(mapKeyName, "mapKey") {
		name := strings.TrimPrefix(mapKeyName, "mapKey")
		if len(name) > 0 {
			return strings.ToUpper(name[:1]) + name[1:]
		}
	}
	return mapKeyName
}

// extractTypeFromStruct returns the rendered type of the single field of a
// key struct, e.g. "*clientpb.Client" for struct{ v *clientpb.Client }.
//
// It returns an error when the struct does not declare exactly one field
// (struct{}, struct{ a, b T } and multi-line field lists are all rejected),
// or when extractTypeString cannot render the field type. Without that second
// check the sentinel "unknown" would be written into generated.go as a type
// name and the generated package would fail to compile.
func extractTypeFromStruct(structType *ast.StructType) (string, error) {
	// NumFields counts every declared name (and is nil-safe), unlike
	// len(Fields.List), which treats `a, b T` as a single entry.
	if n := structType.Fields.NumFields(); n != 1 {
		return "", fmt.Errorf("expected exactly one field in struct, got %d", n)
	}

	field := structType.Fields.List[0]
	typeStr := extractTypeString(field.Type)
	if typeStr == "unknown" {
		return "", fmt.Errorf("unsupported field type %T", field.Type)
	}
	return typeStr, nil
}

// extractTypeString renders a type expression back to Go source, for example
// "*clientpb.Client", "[]string", "[4]byte", "map[string]*pb.Item",
// "chan<- int" or "List[int]". The empty interface is rendered as "any".
//
// If any part of the expression cannot be rendered (non-empty interfaces,
// func and struct literals, computed array lengths, nil, ...), the whole
// result is the sentinel "unknown". A partially rendered string such as
// "*unknown" is never returned.
func extractTypeString(expr ast.Expr) string {
	s, ok := renderTypeExpr(expr)
	if !ok {
		return "unknown"
	}
	return s
}

// renderTypeExpr does the recursive rendering for extractTypeString and
// reports ok=false when expr, or any sub-expression of it, is unsupported.
func renderTypeExpr(expr ast.Expr) (string, bool) {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name, true
	case *ast.StarExpr:
		elem, ok := renderTypeExpr(t.X)
		return "*" + elem, ok
	case *ast.SelectorExpr:
		pkg, ok := renderTypeExpr(t.X)
		return pkg + "." + t.Sel.Name, ok
	case *ast.ParenExpr:
		inner, ok := renderTypeExpr(t.X)
		return "(" + inner + ")", ok
	case *ast.ArrayType:
		elem, ok := renderTypeExpr(t.Elt)
		if !ok {
			return "", false
		}
		switch n := t.Len.(type) {
		case nil:
			return "[]" + elem, true
		case *ast.BasicLit:
			return "[" + n.Value + "]" + elem, true
		case *ast.Ident, *ast.SelectorExpr:
			length, ok := renderTypeExpr(n)
			return "[" + length + "]" + elem, ok
		}
		return "", false // [...]T or a computed length expression
	case *ast.MapType:
		key, okKey := renderTypeExpr(t.Key)
		val, okVal := renderTypeExpr(t.Value)
		return "map[" + key + "]" + val, okKey && okVal
	case *ast.ChanType:
		elem, ok := renderTypeExpr(t.Value)
		switch t.Dir {
		case ast.SEND:
			return "chan<- " + elem, ok
		case ast.RECV:
			return "<-chan " + elem, ok
		}
		return "chan " + elem, ok
	case *ast.IndexExpr:
		return renderInstantiation(t.X, []ast.Expr{t.Index})
	case *ast.IndexListExpr:
		return renderInstantiation(t.X, t.Indices)
	case *ast.InterfaceType:
		if t.Methods == nil || len(t.Methods.List) == 0 {
			return "any", true
		}
	}
	return "", false
}

// renderInstantiation renders a generic instantiation such as List[int] or
// Pair[K, V], failing if the base or any type argument is unsupported.
func renderInstantiation(base ast.Expr, args []ast.Expr) (string, bool) {
	name, ok := renderTypeExpr(base)
	if !ok {
		return "", false
	}
	parts := make([]string, 0, len(args))
	for _, arg := range args {
		s, ok := renderTypeExpr(arg)
		if !ok {
			return "", false
		}
		parts = append(parts, s)
	}
	return name + "[" + strings.Join(parts, ", ") + "]", true
}

// extractImportsFromFile maps each usable package qualifier imported by file
// to an explicit import spec, e.g. "clientpb" -> `clientpb "example.com/clientpb"`.
//
// The qualifier is the explicit import name when present. Otherwise it is
// guessed as the last path element, with a gopkg.in-style ".vN" suffix removed
// ("gopkg.in/yaml.v3" -> "yaml"), since a name containing '.' can never be a
// Go identifier. Blank (_) and dot (.) imports are omitted because no type
// expression can be qualified through them.
func extractImportsFromFile(file *ast.File) map[string]string {
	imports := make(map[string]string, len(file.Imports))

	for _, imp := range file.Imports {
		// Import paths may be interpreted ("...") or raw (`...`) string literals.
		path := strings.Trim(imp.Path.Value, "\"`")

		var alias string
		if imp.Name != nil {
			alias = imp.Name.Name
		} else {
			alias = path[strings.LastIndex(path, "/")+1:]
			if i := strings.LastIndex(alias, ".v"); i > 0 {
				if ver := alias[i+2:]; ver != "" && strings.Trim(ver, "0123456789") == "" {
					alias = alias[:i]
				}
			}
		}

		if alias == "_" || alias == "." {
			continue
		}
		imports[alias] = fmt.Sprintf("%s \"%s\"", alias, path)
	}

	return imports
}

// determineRequiredImports returns the import specs from fileImports whose
// package qualifier appears anywhere in a key's type. It handles compound types
// such as "[]*pb.Item" or "map[a.K]b.V" as well as "*clientpb.Client".
// Specs are returned in order of first use, without duplicates. Qualifiers
// with no entry in fileImports are ignored.
func determineRequiredImports(keys []KeyInfo, fileImports map[string]string) []string {
	var imports []string
	seen := make(map[string]bool)

	for _, key := range keys {
		for _, pkg := range typeQualifiers(key.Type) {
			importStr, ok := fileImports[pkg]
			if !ok || seen[importStr] {
				continue
			}
			seen[importStr] = true
			imports = append(imports, importStr)
		}
	}

	return imports
}

// typeQualifiers returns, in order, the package qualifier (the identifier
// before the '.') of every qualified name in a rendered type expression.
func typeQualifiers(typeStr string) []string {
	// Identifier characters plus '.', so "pb.Item" stays one token while
	// punctuation such as '*', '[', ']' and spaces separates tokens.
	// Non-ASCII runes are treated as identifier characters.
	notIdentOrDot := func(r rune) bool {
		return !(r == '_' || r == '.' || r >= 0x80 ||
			('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9'))
	}

	var pkgs []string
	for _, tok := range strings.FieldsFunc(typeStr, notIdentOrDot) {
		if pkg, _, found := strings.Cut(tok, "."); found && pkg != "" {
			pkgs = append(pkgs, pkg)
		}
	}
	return pkgs
}

// isSkippedKey reports whether a key struct is written in the legacy form
// struct{ mapKeyXxx any }. That is the case when it has exactly one field
// declaration carrying exactly one name, and that name equals mapKeyName.
// The field's type is not inspected. Such keys are not generated.
func isSkippedKey(mapKeyName string, structType *ast.StructType) bool {
	fields := structType.Fields.List
	if len(fields) != 1 {
		return false
	}
	names := fields[0].Names
	return len(names) == 1 && names[0].Name == mapKeyName
}

// hasCustomMustGet reports whether mapKeyName is recorded as true in
// customMustGetKeys, meaning a hand-written MustGet exists and the generator
// must not emit one.
func hasCustomMustGet(mapKeyName string) bool {
	custom, found := customMustGetKeys[mapKeyName]
	return found && custom
}

// extractCustomMustGetKeys finds the package-level variable
// KeysWithCustomMustGet in file and returns the set of key type names it lists.
//
// Keys are recognised as composite literals such as `mapKeyClient{}`, either as
// map-literal keys (map[any]bool{mapKeyClient{}: true}) or as bare elements of
// a slice/array literal ([]any{mapKeyClient{}}). Names are rendered with
// extractTypeString, and anything it reports as "unknown" is ignored. main
// stores the result in customMustGetKeys, which suppresses the generated
// MustGet for those keys.
func extractCustomMustGetKeys(file *ast.File) map[string]bool {
	logln("[gen] scanning KeysWithCustomMustGet via AST...")
	result := map[string]bool{}
	var ordered []string // source order, so the summary log is deterministic

	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}
		for _, spec := range genDecl.Specs {
			vs, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}
			cl := customMustGetLiteral(vs)
			if cl == nil {
				continue
			}
			for _, elt := range cl.Elts {
				// Map literals wrap each key in a KeyValueExpr; slice literals do not.
				if kv, ok := elt.(*ast.KeyValueExpr); ok {
					elt = kv.Key
				}
				keyCL, ok := elt.(*ast.CompositeLit)
				if !ok {
					continue
				}
				name := extractTypeString(keyCL.Type)
				if name == "unknown" || result[name] {
					continue
				}
				logf("[gen]   custom key type: %s\n", name)
				result[name] = true
				ordered = append(ordered, name)
			}
		}
	}

	if len(result) == 0 {
		logln("[gen]   WARN: no custom MustGet keys detected")
	} else {
		logf("[gen]   collected custom keys: %v\n", ordered)
	}
	return result
}

// customMustGetLiteral returns the composite-literal initializer paired by
// position with the name KeysWithCustomMustGet in vs. It returns nil when vs
// does not declare that name or its value is not a composite literal.
func customMustGetLiteral(vs *ast.ValueSpec) *ast.CompositeLit {
	for i, ident := range vs.Names {
		if ident.Name != "KeysWithCustomMustGet" {
			continue
		}
		if i >= len(vs.Values) {
			return nil
		}
		cl, _ := vs.Values[i].(*ast.CompositeLit)
		return cl
	}
	return nil
}

// main regenerates ./generated.go from the mapKey definitions in ./keys.go.
//
// It parses keys.go, records which keys have hand-written MustGet helpers
// (KeysWithCustomMustGet), and collects every package-level key struct.
// It then renders generatedTemplate, gofmt-formats the result and writes it
// out. Any failure is reported on stderr and exits with status 1. Because the
// output is fully rendered in memory before anything is written, a failure
// never leaves a half-written generated.go behind.
func main() {
	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, "./keys.go", nil, parser.ParseComments)
	if err != nil {
		exitf("Error parsing keys.go: %v\n", err)
	}

	// Must be populated before collectKeys, which consults hasCustomMustGet.
	customMustGetKeys = extractCustomMustGetKeys(node)

	keys := collectKeys(node)
	if len(keys) == 0 {
		exitf("No keys found for generation\n")
	}

	imports := determineRequiredImports(keys, extractImportsFromFile(node))

	src, err := renderGenerated(keys, imports)
	if err != nil {
		exitf("Error generating code: %v\n", err)
	}

	// Same permission bits os.Create would use (0666 before umask).
	if err := os.WriteFile("./generated.go", src, 0o666); err != nil {
		exitf("Error writing generated.go: %v\n", err)
	}

	fmt.Printf("Successfully generated functions for %d keys in generated.go\n", len(keys))
}

// collectKeys returns a KeyInfo for every package-level, non-generic struct
// type in file that is not in the legacy struct{ mapKeyXxx any } form and
// whose single field type can be rendered. Skipped types are logged.
func collectKeys(file *ast.File) []KeyInfo {
	var keys []KeyInfo

	// Only top-level declarations are considered. ast.Inspect would also
	// descend into function bodies and pick up locally declared types.
	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}
		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}
			mapKeyName := typeSpec.Name.Name

			// The template instantiates keys as MapKeyName{}, which is not
			// valid for an uninstantiated generic type.
			if typeSpec.TypeParams != nil {
				logf("[gen] Skipping generic type: %s\n", mapKeyName)
				continue
			}
			if isSkippedKey(mapKeyName, structType) {
				logf("[gen] Skipping old format: %s\n", mapKeyName)
				continue
			}

			typeStr, err := extractTypeFromStruct(structType)
			if err != nil {
				logf("[gen] Warning: Failed to extract type for %s: %v\n", mapKeyName, err)
				continue
			}

			functionName := extractFunctionName(mapKeyName)
			keyInfo := KeyInfo{
				Name:          functionName,
				MapKeyName:    mapKeyName,
				Type:          typeStr,
				HasCustomMust: hasCustomMustGet(mapKeyName),
			}
			keys = append(keys, keyInfo)

			if keyInfo.HasCustomMust {
				logf("[gen] Found key: %s -> %s (%s) [custom MustGet]\n", mapKeyName, functionName, typeStr)
			} else {
				logf("[gen] Found key: %s -> %s (%s)\n", mapKeyName, functionName, typeStr)
			}
		}
	}
	return keys
}

// renderGenerated executes generatedTemplate for keys and imports and returns
// the gofmt-formatted Go source.
func renderGenerated(keys []KeyInfo, imports []string) ([]byte, error) {
	tmpl, err := template.New("generated").Parse(generatedTemplate)
	if err != nil {
		return nil, fmt.Errorf("parsing template: %w", err)
	}

	data := struct {
		Keys    []KeyInfo
		Imports []string
	}{
		Keys:    keys,
		Imports: imports,
	}

	var out strings.Builder
	if err := tmpl.Execute(&out, data); err != nil {
		return nil, fmt.Errorf("executing template: %w", err)
	}

	// The template emits stray blank lines. format.Source normalises them and
	// also rejects syntactically invalid output before it reaches disk.
	src, err := format.Source([]byte(out.String()))
	if err != nil {
		return nil, fmt.Errorf("formatting generated code: %w", err)
	}
	return src, nil
}

// exitf reports a fatal error on stderr and terminates with exit status 1.
func exitf(format string, args ...any) {
	fmt.Fprintf(os.Stderr, format, args...)
	os.Exit(1)
}
